{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyReLU(torch.autograd.Function):\n",
    "    @staticmethod\n",
    "    def forward(ctx, x):\n",
    "        ctx.save_for_backward(x)\n",
    "        return x.clamp(min=0)\n",
    "    \n",
    "    def backward(ctx, dout):\n",
    "        (x, ) = ctx.saved_tensors\n",
    "        grad_x = dout.clone()\n",
    "        grad_x[x < 0] = 0\n",
    "        return grad_x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cpu')\n",
    "N, D_in, H, D_out = 64, 1000, 100, 10\n",
    "\n",
    "x = torch.randn(N, D_in, device=device)\n",
    "y = torch.randn(N, D_out, device=device)\n",
    "\n",
    "w1 = torch.randn(D_in, H, device=device, requires_grad=True)\n",
    "w2 = torch.randn(H, D_out, device=device, requires_grad=True)\n",
    "\n",
    "learning_rate = 1e-6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 32566816.0\n",
      "1 31606992.0\n",
      "2 32267054.0\n",
      "3 29542348.0\n",
      "4 22292370.0\n",
      "5 13749096.0\n",
      "6 7431736.5\n",
      "7 3932700.5\n",
      "8 2254174.0\n",
      "9 1461790.5\n",
      "10 1059908.375\n",
      "11 829822.125\n",
      "12 679965.0\n",
      "13 571655.125\n",
      "14 488197.15625\n",
      "15 421159.96875\n",
      "16 365853.15625\n",
      "17 319496.78125\n",
      "18 280254.96875\n",
      "19 246811.15625\n",
      "20 218122.53125\n",
      "21 193369.09375\n",
      "22 171905.390625\n",
      "23 153230.546875\n",
      "24 136926.171875\n",
      "25 122652.5703125\n",
      "26 110135.0703125\n",
      "27 99116.671875\n",
      "28 89381.234375\n",
      "29 80753.53125\n",
      "30 73096.1015625\n",
      "31 66277.4296875\n",
      "32 60213.0625\n",
      "33 54796.19921875\n",
      "34 49940.6796875\n",
      "35 45578.20703125\n",
      "36 41652.68359375\n",
      "37 38112.875\n",
      "38 34916.06640625\n",
      "39 32023.8046875\n",
      "40 29404.8125\n",
      "41 27032.126953125\n",
      "42 24879.421875\n",
      "43 22918.873046875\n",
      "44 21132.78125\n",
      "45 19503.986328125\n",
      "46 18015.62890625\n",
      "47 16655.00390625\n",
      "48 15410.74609375\n",
      "49 14272.515625\n",
      "50 13228.6484375\n",
      "51 12270.4775390625\n",
      "52 11391.7666015625\n",
      "53 10584.15625\n",
      "54 9840.1767578125\n",
      "55 9155.732421875\n",
      "56 8524.2119140625\n",
      "57 7941.22119140625\n",
      "58 7402.666015625\n",
      "59 6904.201171875\n",
      "60 6442.939453125\n",
      "61 6015.6376953125\n",
      "62 5619.58056640625\n",
      "63 5252.23876953125\n",
      "64 4911.490234375\n",
      "65 4594.92919921875\n",
      "66 4300.71484375\n",
      "67 4027.158203125\n",
      "68 3772.765625\n",
      "69 3536.06298828125\n",
      "70 3315.6787109375\n",
      "71 3110.164794921875\n",
      "72 2918.49365234375\n",
      "73 2739.782958984375\n",
      "74 2572.971923828125\n",
      "75 2417.155029296875\n",
      "76 2271.555419921875\n",
      "77 2135.511474609375\n",
      "78 2008.264892578125\n",
      "79 1889.399658203125\n",
      "80 1778.241455078125\n",
      "81 1674.1566162109375\n",
      "82 1576.63134765625\n",
      "83 1485.23974609375\n",
      "84 1399.5086669921875\n",
      "85 1319.0638427734375\n",
      "86 1243.5762939453125\n",
      "87 1172.740966796875\n",
      "88 1106.2015380859375\n",
      "89 1043.703369140625\n",
      "90 984.953125\n",
      "91 929.7330932617188\n",
      "92 877.8174438476562\n",
      "93 828.9539794921875\n",
      "94 782.9844360351562\n",
      "95 739.7149047851562\n",
      "96 698.9999389648438\n",
      "97 660.6436767578125\n",
      "98 624.5115356445312\n",
      "99 590.462158203125\n",
      "100 558.3854370117188\n",
      "101 528.1342163085938\n",
      "102 499.6141357421875\n",
      "103 472.7215576171875\n",
      "104 447.3496398925781\n",
      "105 423.4017333984375\n",
      "106 400.80816650390625\n",
      "107 379.476806640625\n",
      "108 359.3348693847656\n",
      "109 340.31353759765625\n",
      "110 322.34185791015625\n",
      "111 305.36865234375\n",
      "112 289.32611083984375\n",
      "113 274.16546630859375\n",
      "114 259.831787109375\n",
      "115 246.27992248535156\n",
      "116 233.4642333984375\n",
      "117 221.35447692871094\n",
      "118 209.89407348632812\n",
      "119 199.0544891357422\n",
      "120 188.79486083984375\n",
      "121 179.08470153808594\n",
      "122 169.89669799804688\n",
      "123 161.19515991210938\n",
      "124 152.95274353027344\n",
      "125 145.15097045898438\n",
      "126 137.76119995117188\n",
      "127 130.7622528076172\n",
      "128 124.12767028808594\n",
      "129 117.84204864501953\n",
      "130 111.88542175292969\n",
      "131 106.23865509033203\n",
      "132 100.88721466064453\n",
      "133 95.81237030029297\n",
      "134 91.00387573242188\n",
      "135 86.44172668457031\n",
      "136 82.1152114868164\n",
      "137 78.01409149169922\n",
      "138 74.12059020996094\n",
      "139 70.42797088623047\n",
      "140 66.92349243164062\n",
      "141 63.59975051879883\n",
      "142 60.44356155395508\n",
      "143 57.45030212402344\n",
      "144 54.60894775390625\n",
      "145 51.91147994995117\n",
      "146 49.35379409790039\n",
      "147 46.920997619628906\n",
      "148 44.61311721801758\n",
      "149 42.41971206665039\n",
      "150 40.33808135986328\n",
      "151 38.360172271728516\n",
      "152 36.483070373535156\n",
      "153 34.6990852355957\n",
      "154 33.003726959228516\n",
      "155 31.39352798461914\n",
      "156 29.862815856933594\n",
      "157 28.408647537231445\n",
      "158 27.027923583984375\n",
      "159 25.714853286743164\n",
      "160 24.46645164489746\n",
      "161 23.281692504882812\n",
      "162 22.154117584228516\n",
      "163 21.083036422729492\n",
      "164 20.063623428344727\n",
      "165 19.09441566467285\n",
      "166 18.17314910888672\n",
      "167 17.29734992980957\n",
      "168 16.464202880859375\n",
      "169 15.672750473022461\n",
      "170 14.919554710388184\n",
      "171 14.202781677246094\n",
      "172 13.521657943725586\n",
      "173 12.873580932617188\n",
      "174 12.256573677062988\n",
      "175 11.670100212097168\n",
      "176 11.111888885498047\n",
      "177 10.581042289733887\n",
      "178 10.076026916503906\n",
      "179 9.595016479492188\n",
      "180 9.1378173828125\n",
      "181 8.70247745513916\n",
      "182 8.288376808166504\n",
      "183 7.894074440002441\n",
      "184 7.5189948081970215\n",
      "185 7.1617431640625\n",
      "186 6.822150230407715\n",
      "187 6.4979777336120605\n",
      "188 6.190332412719727\n",
      "189 5.897026062011719\n",
      "190 5.6179680824279785\n",
      "191 5.352107524871826\n",
      "192 5.099226474761963\n",
      "193 4.858354091644287\n",
      "194 4.6288161277771\n",
      "195 4.410462856292725\n",
      "196 4.202491283416748\n",
      "197 4.004575729370117\n",
      "198 3.8159286975860596\n",
      "199 3.6360220909118652\n",
      "200 3.465054750442505\n",
      "201 3.302112579345703\n",
      "202 3.146897077560425\n",
      "203 2.9991250038146973\n",
      "204 2.8582763671875\n",
      "205 2.724318504333496\n",
      "206 2.596428871154785\n",
      "207 2.4748165607452393\n",
      "208 2.3588247299194336\n",
      "209 2.248490810394287\n",
      "210 2.143165111541748\n",
      "211 2.0429115295410156\n",
      "212 1.9472779035568237\n",
      "213 1.8564372062683105\n",
      "214 1.7696771621704102\n",
      "215 1.6870956420898438\n",
      "216 1.6083924770355225\n",
      "217 1.533337116241455\n",
      "218 1.4619852304458618\n",
      "219 1.3936939239501953\n",
      "220 1.3287529945373535\n",
      "221 1.2669289112091064\n",
      "222 1.2079746723175049\n",
      "223 1.1518067121505737\n",
      "224 1.0981680154800415\n",
      "225 1.047122597694397\n",
      "226 0.9984019994735718\n",
      "227 0.9520925879478455\n",
      "228 0.9079430103302002\n",
      "229 0.8657414317131042\n",
      "230 0.8256281018257141\n",
      "231 0.7873245477676392\n",
      "232 0.7508416175842285\n",
      "233 0.7160287499427795\n",
      "234 0.6829271912574768\n",
      "235 0.6513044834136963\n",
      "236 0.6211795806884766\n",
      "237 0.5923462510108948\n",
      "238 0.5649361610412598\n",
      "239 0.5388739705085754\n",
      "240 0.5139422416687012\n",
      "241 0.4901299774646759\n",
      "242 0.4675593078136444\n",
      "243 0.44589683413505554\n",
      "244 0.425376832485199\n",
      "245 0.40577971935272217\n",
      "246 0.3870429992675781\n",
      "247 0.36921805143356323\n",
      "248 0.3521740436553955\n",
      "249 0.3359357714653015\n",
      "250 0.3204563856124878\n",
      "251 0.3056580126285553\n",
      "252 0.2916043698787689\n",
      "253 0.2781946361064911\n",
      "254 0.2654664218425751\n",
      "255 0.2532251477241516\n",
      "256 0.24154134094715118\n",
      "257 0.23041194677352905\n",
      "258 0.21985158324241638\n",
      "259 0.20971593260765076\n",
      "260 0.20012278854846954\n",
      "261 0.19088880717754364\n",
      "262 0.1821020096540451\n",
      "263 0.17374663054943085\n",
      "264 0.16579103469848633\n",
      "265 0.15818609297275543\n",
      "266 0.15093542635440826\n",
      "267 0.14401832222938538\n",
      "268 0.13745172321796417\n",
      "269 0.13118043541908264\n",
      "270 0.1251453310251236\n",
      "271 0.11940431594848633\n",
      "272 0.11394285410642624\n",
      "273 0.10871677845716476\n",
      "274 0.10376051813364029\n",
      "275 0.09899488836526871\n",
      "276 0.09446416795253754\n",
      "277 0.09015574306249619\n",
      "278 0.08602941036224365\n",
      "279 0.08209697157144547\n",
      "280 0.07834314554929733\n",
      "281 0.07477746903896332\n",
      "282 0.07136337459087372\n",
      "283 0.06809526681900024\n",
      "284 0.06500241160392761\n",
      "285 0.062031593173742294\n",
      "286 0.05918314307928085\n",
      "287 0.05651196837425232\n",
      "288 0.053927380591630936\n",
      "289 0.051472995430231094\n",
      "290 0.04913382604718208\n",
      "291 0.046899113804101944\n",
      "292 0.044746607542037964\n",
      "293 0.042726241052150726\n",
      "294 0.04079301655292511\n",
      "295 0.038929618895053864\n",
      "296 0.0371667854487896\n",
      "297 0.03546525165438652\n",
      "298 0.03387352451682091\n",
      "299 0.032335810363292694\n",
      "300 0.030854608863592148\n",
      "301 0.02945767715573311\n",
      "302 0.028129708021879196\n",
      "303 0.02684634178876877\n",
      "304 0.02561594918370247\n",
      "305 0.024477742612361908\n",
      "306 0.023382028564810753\n",
      "307 0.02232476696372032\n",
      "308 0.021318187937140465\n",
      "309 0.020350759848952293\n",
      "310 0.019442763179540634\n",
      "311 0.018574893474578857\n",
      "312 0.017738476395606995\n",
      "313 0.01694534160196781\n",
      "314 0.016187164932489395\n",
      "315 0.015461066737771034\n",
      "316 0.014767187647521496\n",
      "317 0.014103109017014503\n",
      "318 0.013471228070557117\n",
      "319 0.01287179347127676\n",
      "320 0.012306129559874535\n",
      "321 0.011758225969970226\n",
      "322 0.011238849721848965\n",
      "323 0.010741827078163624\n",
      "324 0.010269575752317905\n",
      "325 0.009815417230129242\n",
      "326 0.009382710792124271\n",
      "327 0.008970407769083977\n",
      "328 0.008576283231377602\n",
      "329 0.008197213523089886\n",
      "330 0.007835159078240395\n",
      "331 0.007504821754992008\n",
      "332 0.00717319967225194\n",
      "333 0.006858321838080883\n",
      "334 0.006560681387782097\n",
      "335 0.006279774941504002\n",
      "336 0.006011297460645437\n",
      "337 0.0057520451955497265\n",
      "338 0.005505464971065521\n",
      "339 0.00527121452614665\n",
      "340 0.0050480980426073074\n",
      "341 0.004838684573769569\n",
      "342 0.0046399435959756374\n",
      "343 0.004441185388714075\n",
      "344 0.004254056606441736\n",
      "345 0.004070057068020105\n",
      "346 0.00390408793464303\n",
      "347 0.0037409267388284206\n",
      "348 0.003586284816265106\n",
      "349 0.0034404911566525698\n",
      "350 0.0032996756490319967\n",
      "351 0.003167463466525078\n",
      "352 0.003038448514416814\n",
      "353 0.0029158343095332384\n",
      "354 0.002799950074404478\n",
      "355 0.002685422543436289\n",
      "356 0.0025785211473703384\n",
      "357 0.002475819317623973\n",
      "358 0.00237587234005332\n",
      "359 0.0022848956286907196\n",
      "360 0.002194733591750264\n",
      "361 0.0021088372450321913\n",
      "362 0.002027371432632208\n",
      "363 0.001949044526554644\n",
      "364 0.0018720009829849005\n",
      "365 0.001800810219720006\n",
      "366 0.0017366157844662666\n",
      "367 0.0016722317086532712\n",
      "368 0.0016112936427816749\n",
      "369 0.0015495707048103213\n",
      "370 0.0014929062454029918\n",
      "371 0.0014394106110557914\n",
      "372 0.0013863315107300878\n",
      "373 0.0013362084282562137\n",
      "374 0.001290175598114729\n",
      "375 0.0012451931834220886\n",
      "376 0.0012011214857921004\n",
      "377 0.0011581802973523736\n",
      "378 0.0011177081614732742\n",
      "379 0.0010782030876725912\n",
      "380 0.0010426051449030638\n",
      "381 0.0010047854157164693\n",
      "382 0.0009737872751429677\n",
      "383 0.0009405862656421959\n",
      "384 0.0009098178124986589\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "385 0.0008801647927612066\n",
      "386 0.0008516847155988216\n",
      "387 0.0008249807870015502\n",
      "388 0.0007977974601089954\n",
      "389 0.0007707941113039851\n",
      "390 0.0007455244776792824\n",
      "391 0.0007218625396490097\n",
      "392 0.0006998440367169678\n",
      "393 0.0006773574859835207\n",
      "394 0.0006558149470947683\n",
      "395 0.000635328353382647\n",
      "396 0.0006163360667414963\n",
      "397 0.000597754551563412\n",
      "398 0.0005801969091407955\n",
      "399 0.0005621873424388468\n",
      "400 0.0005437657819129527\n",
      "401 0.000528572010807693\n",
      "402 0.0005118754925206304\n",
      "403 0.0004973375471308827\n",
      "404 0.00048320487258024514\n",
      "405 0.00046797608956694603\n",
      "406 0.00045506394235417247\n",
      "407 0.00044234763481654227\n",
      "408 0.0004304910253267735\n",
      "409 0.0004183158453088254\n",
      "410 0.0004057384794577956\n",
      "411 0.0003943214542232454\n",
      "412 0.0003833928203675896\n",
      "413 0.0003737035149242729\n",
      "414 0.0003615999303292483\n",
      "415 0.0003530292888171971\n",
      "416 0.00034304294968023896\n",
      "417 0.0003332844644319266\n",
      "418 0.0003247404529247433\n",
      "419 0.0003172496217302978\n",
      "420 0.000308716029394418\n",
      "421 0.00030055566458031535\n",
      "422 0.0002935499942395836\n",
      "423 0.0002862757246475667\n",
      "424 0.00027853803476318717\n",
      "425 0.0002719420299399644\n",
      "426 0.00026495131896808743\n",
      "427 0.000258590211160481\n",
      "428 0.00025203500990755856\n",
      "429 0.0002459651732351631\n",
      "430 0.00024028428015299141\n",
      "431 0.00023505794524680823\n",
      "432 0.00022915180306881666\n",
      "433 0.0002235878346255049\n",
      "434 0.00021810260659549385\n",
      "435 0.00021307003044057637\n",
      "436 0.0002078826364595443\n",
      "437 0.0002032561897067353\n",
      "438 0.00019876616715919226\n",
      "439 0.00019445817451924086\n",
      "440 0.00019077336764894426\n",
      "441 0.000185669501661323\n",
      "442 0.00018117106810677797\n",
      "443 0.00017758551985025406\n",
      "444 0.0001737493003020063\n",
      "445 0.00017031835159286857\n",
      "446 0.0001663821458350867\n",
      "447 0.00016280787531286478\n",
      "448 0.00015998267917893827\n",
      "449 0.00015622921637259424\n",
      "450 0.00015237941988743842\n",
      "451 0.0001497977355029434\n",
      "452 0.00014639509026892483\n",
      "453 0.0001443110522814095\n",
      "454 0.00014109072799328715\n",
      "455 0.00013853158452548087\n",
      "456 0.00013539094652514905\n",
      "457 0.00013275237870402634\n",
      "458 0.00013020244659855962\n",
      "459 0.00012780865654349327\n",
      "460 0.00012543442426249385\n",
      "461 0.0001229849294759333\n",
      "462 0.00012059471919201314\n",
      "463 0.00011828235437860712\n",
      "464 0.00011587249900912866\n",
      "465 0.00011364651436451823\n",
      "466 0.00011115205415990204\n",
      "467 0.0001091132071451284\n",
      "468 0.0001070458092726767\n",
      "469 0.00010540820949245244\n",
      "470 0.0001033720254781656\n",
      "471 0.00010145317355636507\n",
      "472 9.953746484825388e-05\n",
      "473 9.816699457587674e-05\n",
      "474 9.629638225305825e-05\n",
      "475 9.466023038839921e-05\n",
      "476 9.302741091232747e-05\n",
      "477 9.168006363324821e-05\n",
      "478 8.975407399702817e-05\n",
      "479 8.831321611069143e-05\n",
      "480 8.669614908285439e-05\n",
      "481 8.530083869118243e-05\n",
      "482 8.372577576665208e-05\n",
      "483 8.263878407888114e-05\n",
      "484 8.077006350504234e-05\n",
      "485 7.965923578012735e-05\n",
      "486 7.831626135157421e-05\n",
      "487 7.710434874752536e-05\n",
      "488 7.570152229163796e-05\n",
      "489 7.476107566617429e-05\n",
      "490 7.340348383877426e-05\n",
      "491 7.20209936844185e-05\n",
      "492 7.124034891603515e-05\n",
      "493 7.042902143439278e-05\n",
      "494 6.911814853083342e-05\n",
      "495 6.843823211966082e-05\n",
      "496 6.752501212758943e-05\n",
      "497 6.622353248530999e-05\n",
      "498 6.513982953038067e-05\n",
      "499 6.436094554373994e-05\n"
     ]
    }
   ],
   "source": [
    "for t in range(500):\n",
    "    y_pred = MyReLU.apply(x.mm(w1)).mm(w2)\n",
    "    loss = (y_pred - y).pow(2).sum()\n",
    "    print(t, loss.item())\n",
    "    loss.backward()\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        w1 -= learning_rate * w1.grad\n",
    "        w2 -= learning_rate * w2.grad\n",
    "\n",
    "        w1.grad.zero_()\n",
    "        w2.grad.zero_()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
